from torch.utils.data import Dataset
import torch
import numpy as np
import random

class SMTWTDataset(Dataset):
    """Randomly generated scheduling instances stored as one tensor.

    NOTE(review): the name suggests single-machine total weighted tardiness
    (SMTWT) instances -- confirm against callers.

    ``self.data`` has shape ``(batch_size, n_jobs, 3)``. Each item
    ``self.data[i]`` is an ``(n_jobs, 3)`` tensor whose last-axis columns are,
    in order:

    * ``prt`` -- uniform in ``[0, 1)`` (presumably processing time; verify),
    * due time -- uniform in ``[0, n_jobs)``,
    * weight -- uniform in ``[0, 1)``.

    Args:
        batch_size: Number of instances to generate; also the dataset length.
        n_jobs: Number of jobs per instance.
        seed: Optional seed. When not ``None``, the *global* ``random``,
            NumPy and torch (CPU and all CUDA devices) RNGs are reseeded
            before generation. The data is then reproducible, but the
            process-wide RNG state changes as a side effect.
    """

    def __init__(self, batch_size, n_jobs, seed=None):
        super().__init__()
        self.n_jobs = n_jobs
        self.batch_size = batch_size
        # Sample count. __len__ used to read this name without it ever being
        # assigned; it is kept for any code that reads the attribute.
        self.n_samples = batch_size
        self.seed = seed
        # Seed before generating: make_dataset draws from the global torch RNG.
        self.set_seed(seed)
        self.data = self.make_dataset()

    def set_seed(self, seed):
        """Reseed the global Python, NumPy and torch RNGs if *seed* is not None."""
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    def __len__(self):
        """Return the number of generated instances (first dim of ``data``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return instance *idx*, an ``(n_jobs, 3)`` tensor."""
        return self.data[idx]

    def make_dataset(self):
        """Generate all instances as a ``(batch_size, n_jobs, 3)`` tensor.

        Uses the global torch RNG, so results depend on any prior seeding.
        """
        prt = torch.rand(self.batch_size, self.n_jobs, 1)
        due_time_norm = torch.rand(self.batch_size, self.n_jobs, 1)
        # Scale due times to [0, n_jobs) so they grow with instance size.
        due_time = due_time_norm * self.n_jobs
        weights = torch.rand(self.batch_size, self.n_jobs, 1)
        # Column order on the last axis: prt, due_time, weight.
        dataset = torch.cat([prt, due_time, weights], dim=-1)
        return dataset
